{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f49b3474",
   "metadata": {},
   "source": [
    "### This notebook includes utility functions for Flash."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e537a18-b3ec-48fb-acd3-76619bdc7ce4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def Split_logs_into_chunks(df, r, start_id):\n",
    "    \"\"\"\n",
    "    Splits a DataFrame of logs into consecutive chunks of 'r' rows.\n",
    "\n",
    "    Rows are numbered by position (not by index label): the first 'r' rows get\n",
    "    graph number 'start_id', the next 'r' rows get 'start_id + 1', and so on.\n",
    "    The last chunk may hold fewer than 'r' rows.\n",
    "\n",
    "    Parameters:\n",
    "    df (pandas.DataFrame): The DataFrame containing log data. It is modified in place.\n",
    "    r (int): The number of rows in each chunk. Must be positive.\n",
    "    start_id (int): The starting number for graph numbering.\n",
    "\n",
    "    Returns:\n",
    "    pandas.DataFrame: The same DataFrame object with an additional (or overwritten)\n",
    "                      column 'graph_no' indicating the graph number for each row.\n",
    "\n",
    "    Raises:\n",
    "    ValueError: If 'r' is not positive.\n",
    "    \"\"\"\n",
    "    if r <= 0:\n",
    "        raise ValueError(f'r must be a positive integer, got {r!r}')\n",
    "\n",
    "    # Assign the whole column in one statement. The earlier chained assignment\n",
    "    # (df.iloc[...]['graph_no'] = ...) could write to a temporary copy, and the\n",
    "    # follow-up df[mask] = value overwrote every column of the matching rows.\n",
    "    df['graph_no'] = [start_id + position // r for position in range(len(df))]\n",
    "    return df\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c4175bb-07f2-42b2-bc64-f43e6de59f6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def Modify_Structure(attack_df, benign_df, benign_structure_count, malicious_procs=None):\n",
    "    \"\"\"\n",
    "    Combines attack data with a subset of benign events.\n",
    "\n",
    "    This function selects the first 'benign_structure_count' distinct benign processes\n",
    "    (in order of first appearance in 'benign_df'), keeps every benign event whose\n",
    "    'actorID' is one of them, and appends those events to the attack events. It then\n",
    "    adds one synthetic event per (malicious process, selected benign process) pair,\n",
    "    with the malicious process as 'actorID' and the benign process as 'objectID'.\n",
    "\n",
    "    Parameters:\n",
    "    attack_df (pandas.DataFrame): DataFrame containing attack data.\n",
    "    benign_df (pandas.DataFrame): DataFrame containing benign data; must have an 'actorID' column.\n",
    "    benign_structure_count (int): The number of unique benign processes to select.\n",
    "    malicious_procs (iterable, optional): Malicious process identifiers to pair with the\n",
    "                      selected benign processes. Defaults to the notebook-level 'GT_mal'.\n",
    "\n",
    "    Returns:\n",
    "    pandas.DataFrame: A combined DataFrame (fresh 0..n-1 index) containing the attack rows,\n",
    "                      the selected benign rows, and the newly created events, in that order.\n",
    "    \"\"\"\n",
    "    if malicious_procs is None:\n",
    "        # Backward-compatible default: fall back to the global ground-truth list.\n",
    "        malicious_procs = GT_mal\n",
    "\n",
    "    # drop_duplicates keeps first-seen order, so the selection is reproducible;\n",
    "    # slicing list(set(...)) picked processes in arbitrary order.\n",
    "    unique_benign_procs = benign_df['actorID'].drop_duplicates().tolist()\n",
    "    selected_benign_procs = unique_benign_procs[:benign_structure_count]\n",
    "    selected_benign_df = benign_df[benign_df['actorID'].isin(selected_benign_procs)]\n",
    "\n",
    "    event_template = {'actorID': None, 'objectID': None, 'action': '', 'timestamp': '', 'exec': '', 'path': ''}\n",
    "\n",
    "    # Collect all synthetic events first and concatenate once: DataFrame.append was\n",
    "    # removed in pandas 2.0, and appending row by row copies the frame every time.\n",
    "    new_events = [\n",
    "        {**event_template, 'actorID': malicious_proc, 'objectID': benign_proc}\n",
    "        for malicious_proc in malicious_procs\n",
    "        for benign_proc in selected_benign_procs\n",
    "    ]\n",
    "\n",
    "    frames = [attack_df, selected_benign_df]\n",
    "    if new_events:\n",
    "        frames.append(pd.DataFrame(new_events))\n",
    "\n",
    "    return pd.concat(frames, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01e2f52c-9b32-4f21-8a0b-7672a418fb58",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_xgb_model(word_encodings, gnn_encodings, labels, model_filename):\n",
    "    \"\"\"\n",
    "    Trains an XGBoost classifier using word and GNN encodings and saves the model.\n",
    "\n",
    "    The two encodings are stacked column-wise into one feature matrix, an XGBoost\n",
    "    classifier with default settings is fitted on it, and the fitted classifier is\n",
    "    pickled to 'model_filename'. The reported accuracy is measured on the same data\n",
    "    the model was trained on, so it is a training accuracy, not a held-out score.\n",
    "\n",
    "    Parameters:\n",
    "    word_encodings (numpy.ndarray): The word encodings as a NumPy array.\n",
    "    gnn_encodings (numpy.ndarray): The GNN encodings as a NumPy array; must have the\n",
    "                      same number of rows as 'word_encodings'.\n",
    "    labels (numpy.ndarray): The labels for the training data.\n",
    "    model_filename (str): The filename where the trained model will be saved.\n",
    "\n",
    "    Returns:\n",
    "    tuple: A tuple containing the trained XGBoost classifier and its training accuracy.\n",
    "    \"\"\"\n",
    "\n",
    "    x = np.hstack((word_encodings, gnn_encodings))\n",
    "    y = labels\n",
    "\n",
    "    xgb_cl = xgb.XGBClassifier()\n",
    "    xgb_cl.fit(x, y)\n",
    "\n",
    "    # Use a context manager so the file handle is flushed and closed even if\n",
    "    # pickling fails; a bare open() left the handle to the garbage collector.\n",
    "    with open(model_filename, \"wb\") as model_file:\n",
    "        pickle.dump(xgb_cl, model_file)\n",
    "\n",
    "    preds = xgb_cl.predict(x)\n",
    "    accuracy = accuracy_score(y, preds)\n",
    "\n",
    "    return xgb_cl, accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8c1e362-9528-47fd-9dcf-ad0b98b8d37d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_gnn(df, model, optimizer, device, num_epochs, verbose=False):\n",
    "    \"\"\"\n",
    "    Trains a graph neural network (GNN) model using the provided dataset.\n",
    "\n",
    "    This function prepares a graph from the given dataframe, trains the GNN model for\n",
    "    a specified number of epochs on the full graph, and computes the encoding of the\n",
    "    nodes after training. It uses CrossEntropyLoss with 'balanced' class weights for\n",
    "    imbalance handling.\n",
    "\n",
    "    Parameters:\n",
    "    df (pandas.DataFrame): The dataframe containing the dataset for training.\n",
    "    model: The GNN model to be trained. It is called as model(x, edge_index) and must\n",
    "           provide an encode(x, edge_index) method. The graph tensors are moved to\n",
    "           'device'; the model itself is not moved by this function.\n",
    "    optimizer: The optimizer used for training the model.\n",
    "    device: The device (CPU/GPU) on which the training is performed.\n",
    "    num_epochs (int): The number of epochs for training the model.\n",
    "    verbose (bool): If True, evaluate the model after every epoch and print the loss\n",
    "                    and the training accuracy. Defaults to False.\n",
    "\n",
    "    Returns:\n",
    "    tuple: A tuple containing the trained GNN model and the node encodings as a NumPy array.\n",
    "    \"\"\"\n",
    "\n",
    "    phrases, labels, edges, _ = prepare_graph(df)\n",
    "    word_encodings = np.array([infer(x) for x in phrases])\n",
    "\n",
    "    graph = Data(x=torch.tensor(word_encodings, dtype=torch.float).to(device),\n",
    "                 y=torch.tensor(labels, dtype=torch.long).to(device),\n",
    "                 edge_index=torch.tensor(edges, dtype=torch.long).to(device))\n",
    "\n",
    "    class_weights = class_weight.compute_class_weight('balanced', classes=np.unique(labels), y=labels)\n",
    "    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)\n",
    "    criterion = CrossEntropyLoss(weight=class_weights, reduction='mean')\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        model.train()\n",
    "        optimizer.zero_grad()\n",
    "        out = model(graph.x, graph.edge_index)\n",
    "        loss = criterion(out, graph.y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # The per-epoch evaluation used to run unconditionally and its results were\n",
    "        # thrown away; only pay for the extra forward pass when it is reported.\n",
    "        if verbose:\n",
    "            model.eval()\n",
    "            with torch.no_grad():\n",
    "                pred = model(graph.x, graph.edge_index).argmax(dim=1)\n",
    "                accuracy = (pred == graph.y).sum().item() / len(graph.y)\n",
    "            print(f'Epoch {epoch + 1}/{num_epochs} - loss: {loss.item():.4f}, accuracy: {accuracy:.4f}')\n",
    "\n",
    "    model.eval()\n",
    "    # No gradients are needed for the final encodings, so skip building the autograd graph.\n",
    "    with torch.no_grad():\n",
    "        gnn_encodings = model.encode(graph.x, graph.edge_index).detach().cpu().numpy()\n",
    "\n",
    "    return model, gnn_encodings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95176f7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def Validate(nodes, labels, edges, model):\n",
    "    \"\"\"\n",
    "    Validates a graph using the provided GNN model and calculates confidence scores for each node.\n",
    "\n",
    "    The confidence of a node is the relative margin between its two highest model\n",
    "    outputs, (top1 - top2) / top1, rescaled over the graph as (conf - min) / max.\n",
    "\n",
    "    Parameters:\n",
    "    nodes (List[List[float]]): Node features for the graph.\n",
    "    labels (List[int]): Labels for each node in the graph.\n",
    "    edges (List[List[int]]): Edge indices of the graph.\n",
    "    model (torch.nn.Module): The PyTorch model to be used for evaluation. It is called as\n",
    "                             model(x, edge_index) and must output at least two scores per node.\n",
    "\n",
    "    Returns:\n",
    "    List[float]: A list of confidence scores for each node in the graph where the model's prediction was incorrect.\n",
    "                 Empty if 'nodes' is empty.\n",
    "    \"\"\"\n",
    "    if len(nodes) == 0:\n",
    "        # min()/max() below are undefined on an empty tensor.\n",
    "        return []\n",
    "\n",
    "    # Put the inputs on the device the model lives on; choosing CUDA just because it\n",
    "    # is available fails with a device mismatch when the model is still on the CPU.\n",
    "    try:\n",
    "        device = next(model.parameters()).device\n",
    "    except StopIteration:\n",
    "        # Parameter-free model: keep the previous device selection.\n",
    "        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "    x = torch.tensor(nodes, dtype=torch.float).to(device)\n",
    "    y = torch.tensor(labels, dtype=torch.long).to(device)\n",
    "    edge_index = torch.tensor(edges, dtype=torch.long).to(device)\n",
    "\n",
    "    model.eval()\n",
    "    # Inference only: do not record an autograd graph.\n",
    "    with torch.no_grad():\n",
    "        out = model(x, edge_index)\n",
    "\n",
    "    ranked, indices = out.sort(dim=1, descending=True)\n",
    "    conf = (ranked[:, 0] - ranked[:, 1]) / ranked[:, 0]\n",
    "    # NOTE(review): this divides by max rather than (max - min), so it is not a strict\n",
    "    # min-max normalisation; left unchanged because downstream thresholds may depend on it.\n",
    "    conf = (conf - conf.min()) / conf.max()\n",
    "\n",
    "    pred = indices[:, 0]\n",
    "    misclassified = pred != y\n",
    "\n",
    "    return conf[misclassified].tolist()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PyTorch 1.12.0",
   "language": "python",
   "name": "pytorch-1.12.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
